package day05;

import beans.SensorReading;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.java.StreamTableEnvironment;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.types.Row;

/**
 * Flink Table API &amp; SQL — user-defined aggregate function (UDAF) demo.
 * <p>
 * Reads sensor readings from a text file, computes the average temperature per sensor id with a custom
 * {@link AggregateFunction}, and prints the result twice: once through the Table API and once through SQL.
 * <p>
 * Alibaba Cloud documentation: https://help.aliyun.com/document_detail/69553.html
 *
 * @author lvbingbing
 * @date 2022-01-22 13:16
 */
public class FlinkTableApi11 {

    /**
     * Input file read when no path is passed as the first program argument.
     */
    private static final String DEFAULT_INPUT_PATH = "input/sensor.txt";

    public static void main(String[] args) throws Exception {
        // 1. Create the execution environment.
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(1);
        // 2. Read the data from a text file; the path may be overridden by the first program argument.
        String inputPath = (args != null && args.length > 0) ? args[0] : DEFAULT_INPUT_PATH;
        DataStream<SensorReading> dataStream = env.readTextFile(inputPath)
                .map((String line) -> parseSensorReading(line));
        // 3. Create the table environment.
        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
        // 4. Convert the stream into a table, renaming timestamp -> ts and temperature -> temp.
        Table sensorTable = tableEnv.fromDataStream(dataStream, "id, timestamp as ts, temperature as temp");
        // 5. Run the user-defined aggregate function demo.
        studyUserDefinedAggregateFunction(sensorTable, tableEnv);
        // 6. Trigger job execution.
        env.execute();
    }

    /**
     * Parses one input line of the form {@code id,timestamp,temperature} into a {@link SensorReading}.
     * Whitespace around each field is ignored.
     *
     * @param line a comma-separated input line
     * @return the parsed reading
     * @throws IllegalArgumentException if the line has fewer than three fields
     * @throws NumberFormatException    if the timestamp or temperature field is not numeric
     */
    private static SensorReading parseSensorReading(String line) {
        String[] fields = line.split(",");
        if (fields.length < 3) {
            throw new IllegalArgumentException(
                    "Malformed sensor line, expected 'id,timestamp,temperature': " + line);
        }
        return new SensorReading(
                fields[0].trim(),
                Long.parseLong(fields[1].trim()),
                Double.parseDouble(fields[2].trim()));
    }

    /**
     * Computes the average temperature per sensor id with a user-defined aggregate function. The
     * computation is done once with the Table API and once with SQL, and both results are printed as
     * retract streams.
     *
     * @param sensorTable table with columns {@code id}, {@code ts} and {@code temp}
     * @param tableEnv    table environment used to register the function and run the queries
     */
    private static void studyUserDefinedAggregateFunction(Table sensorTable, StreamTableEnvironment tableEnv) {
        // 1. Create the aggregate function instance.
        AverageAggregateFunction averageAggregateFunction = new AverageAggregateFunction();
        // 2. Register it under the name used by both the Table API expression and the SQL query.
        tableEnv.registerFunction("averageFunc", averageAggregateFunction);
        // 3. Table API: group by id and aggregate the temperature.
        Table aggregatedTable = sensorTable.groupBy("id")
                .aggregate("averageFunc(temp) as avgTemp")
                .select("id, avgTemp");
        // A grouped aggregation updates earlier results, so it must be converted to a retract stream.
        DataStream<Tuple2<Boolean, Row>> tableApiResultStream = tableEnv.toRetractStream(aggregatedTable, Row.class);
        tableApiResultStream.print("aggregatedResult");
        // 4. SQL: expose the table as a view and run the equivalent query.
        tableEnv.createTemporaryView("sensor", sensorTable);
        String sql = "select id, averageFunc(temp) as avgTemp from sensor group by id";
        Table sqlQueryResult = tableEnv.sqlQuery(sql);
        DataStream<Tuple2<Boolean, Row>> sqlResultStream = tableEnv.toRetractStream(sqlQueryResult, Row.class);
        sqlResultStream.print("sqlQueryResult");
    }

    /**
     * User-defined aggregate function computing the average of a {@code DOUBLE} column.
     * <p>
     * The accumulator is a {@link Tuple2} whose {@code f0} holds the running sum and whose {@code f1}
     * holds the number of non-null values accumulated. Like SQL {@code AVG}, {@code null} inputs are
     * ignored, and the result is {@code null} when no values have been accumulated.
     */
    public static class AverageAggregateFunction extends AggregateFunction<Double, Tuple2<Double, Integer>> {

        /**
         * Returns the current average, or {@code null} if no values have been accumulated.
         */
        @Override
        public Double getValue(Tuple2<Double, Integer> accumulator) {
            // Without this guard an empty accumulator would yield NaN (0.0 / 0).
            if (accumulator.f1 == 0) {
                return null;
            }
            return accumulator.f0 / accumulator.f1;
        }

        /**
         * Creates an empty accumulator (sum 0.0, count 0).
         */
        @Override
        public Tuple2<Double, Integer> createAccumulator() {
            return new Tuple2<>(0.0, 0);
        }

        /**
         * Mandatory method that Flink calls for every input row to update the accumulator.
         *
         * @param accumulator the accumulator to update
         * @param temp        the temperature of the current row; {@code null} values are skipped
         */
        @SuppressWarnings("unused")
        public void accumulate(Tuple2<Double, Integer> accumulator, Double temp) {
            // Skip NULL column values; unboxing them would throw a NullPointerException.
            if (temp == null) {
                return;
            }
            accumulator.f0 += temp;
            accumulator.f1 += 1;
        }

        /**
         * Optional method that undoes a previous {@link #accumulate}. It is needed when the input
         * contains retractions, for example an aggregation over another aggregation's result.
         *
         * @param accumulator the accumulator to update
         * @param temp        the temperature being retracted; {@code null} values are skipped
         */
        @SuppressWarnings("unused")
        public void retract(Tuple2<Double, Integer> accumulator, Double temp) {
            if (temp == null) {
                return;
            }
            accumulator.f0 -= temp;
            accumulator.f1 -= 1;
        }

        /**
         * Optional method that folds several partial accumulators into {@code accumulator}. It is
         * needed for session/bounded windows and local-global aggregation.
         *
         * @param accumulator the target accumulator
         * @param others      partial accumulators to merge in
         */
        @SuppressWarnings("unused")
        public void merge(Tuple2<Double, Integer> accumulator, Iterable<Tuple2<Double, Integer>> others) {
            for (Tuple2<Double, Integer> other : others) {
                accumulator.f0 += other.f0;
                accumulator.f1 += other.f1;
            }
        }

        /**
         * Optional method that resets the accumulator to its empty state.
         *
         * @param accumulator the accumulator to reset
         */
        @SuppressWarnings("unused")
        public void resetAccumulator(Tuple2<Double, Integer> accumulator) {
            accumulator.f0 = 0.0;
            accumulator.f1 = 0;
        }
    }
}
